import pytest
import tensorflow as tf
import numpy as np
from tests.attacks.utils import backend_test_classifier_type_check_fail
from art.attacks.evasion.pe_malware_attack import MalwareGDTensorFlow
from art.estimators.classification import TensorFlowV2Classifier

from art.estimators.estimator import BaseEstimator, NeuralNetworkMixin
from art.estimators.classification.classifier import ClassifierMixin

from tests.utils import ARTTestException


@pytest.fixture()
def fix_get_synthetic_data():
    """
    Build a synthetic batch standing in for real malware, which is hard to share.

    Five rows of length 2 ** 20 are created. Each row is filled with random byte values up to its (clipped) file
    size and padded with the value 256 afterwards. The first two entries of every row are set to 77, 90 and a dummy
    pointer (44, 1, 0, 0) is written at offset 0x3C.

    :return: A tuple `(data, labels, file_sizes)`. `file_sizes` holds the unclipped sizes, so the second entry is
             larger than the row length.
    """
    pad_value = 256
    max_length = 2 ** 20
    num_samples = 5
    pointer_offset = 0x3C

    original_sizes = [
        int(max_length * 0.1),  # significantly smaller than the maximum length
        int(max_length * 1.5),  # larger than the maximum length
        int(max_length * 0.95),  # close to the maximum length
        int(max_length),  # exactly at the maximum length
        int(max_length),  # exactly at the maximum length; labelled benign below, so it should not be perturbed
    ]

    # Two class option, later change to binary when ART is generally updated.
    # The first 4 datapoints are labelled as malware, the last one as benign.
    labels = np.zeros((num_samples, 1))
    labels[:4] = 1

    # Start from rows made entirely of padding, then overwrite the "file" portion with random values.
    data = np.full((num_samples, max_length), pad_value, dtype=np.uint16)
    for row, file_size in enumerate(original_sizes):
        n_bytes = min(file_size, max_length)
        data[row, :n_bytes] = np.random.randint(low=0, high=256, size=(1, n_bytes))

    # DOS header values: leading marker, then a zeroed pointer field holding a dummy pointer.
    data[:, :2] = [77, 90]
    data[:, pointer_offset : pointer_offset + 4] = 0
    data[:, pointer_offset] = 44
    data[:, pointer_offset + 1] = 1

    return data, labels, np.asarray(original_sizes)


@pytest.fixture()
def fix_make_dummy_model():
    """
    Build a randomly initialised gated-convolution classifier together with a random embedding matrix.

    :return: A tuple `(classifier, embedding_weights)`: a `TensorFlowV2Classifier` wrapping the Keras model and a
             `(257, 8)` array of normally distributed values.
    """

    def get_prediction_model(param_dic):
        """
        Create the model mapping embeddings to predictions, so the malware embedding can be optimised directly.
        It needs to have the same structure as the target model; "standard" parameters are used here.

        :param param_dic: Dictionary providing the `maxlen` and `embedding_size` entries used for the input shape.
        :return: A Keras model with a single un-activated output unit.
        """
        embedded_input = tf.keras.layers.Input(shape=(param_dic["maxlen"], param_dic["embedding_size"]))

        # Both convolution branches share everything except the activation and the layer name.
        shared_conv_args = {
            "filters": 128,
            "kernel_size": 500,
            "strides": 500,
            "use_bias": True,
            "padding": "valid",
        }
        filter_branch = tf.keras.layers.Conv1D(activation="relu", name="filt_layer", **shared_conv_args)(
            embedded_input
        )
        attention_branch = tf.keras.layers.Conv1D(activation="sigmoid", name="attn_layer", **shared_conv_args)(
            embedded_input
        )

        gated_features = tf.keras.layers.Multiply()([filter_branch, attention_branch])
        pooled_features = tf.keras.layers.GlobalMaxPooling1D()(gated_features)
        hidden = tf.keras.layers.Dense(128, activation="relu", name="dense_layer")(pooled_features)
        logits = tf.keras.layers.Dense(1, name="output_layer")(hidden)
        return tf.keras.Model(inputs=embedded_input, outputs=logits)

    settings = {"maxlen": 2 ** 20, "input_dim": 257, "embedding_size": 8}
    network = get_prediction_model(settings)

    embedding_matrix = np.random.normal(loc=0, scale=1.0, size=(257, 8))

    wrapped_classifier = TensorFlowV2Classifier(
        model=network,
        nb_classes=2,
        loss_object=tf.keras.losses.BinaryCrossentropy(from_logits=True),
        input_shape=(settings["maxlen"], settings["embedding_size"]),
    )

    return wrapped_classifier, embedding_matrix


@pytest.mark.skip_framework("pytorch", "mxnet", "non_dl_frameworks", "tensorflow1", "keras", "kerastf", "tensorflow2v1")
def test_no_perturbation(art_warning, fix_get_synthetic_data, fix_make_dummy_model):
    """
    Assert that with an L0 budget of 0 the data comes back unmodified.
    """
    try:
        classifier, embedding_weights = fix_make_dummy_model
        reference_x, reference_y, reference_sizes = fix_get_synthetic_data

        attack = MalwareGDTensorFlow(
            classifier=classifier,
            embedding_weights=embedding_weights,
            l_0=0,
            param_dic={"maxlen": 2 ** 20, "input_dim": 257, "embedding_size": 8},
        )
        attack.l_0 = 0

        # Work on copies so the fixture data stays available as a reference.
        x = np.copy(reference_x)
        y = np.copy(reference_y)
        size_of_files = np.copy(reference_sizes)

        adv_x, adv_y, adv_sizes = attack.pull_out_valid_samples(x, y, size_of_files)

        # Only 3 files remain, as the following cannot be converted to valid adversarial samples:
        #   2nd datapoint (file too large to support any modifications)
        #   5th datapoint (benign file)
        assert len(adv_x) == 3

        adv_x = attack.generate(adv_x, adv_y, adv_sizes)

        attacked_rows = (0, 2, 3)
        adv_row = 0
        for row, working_sample in enumerate(x):
            if row in attacked_rows:
                # No budget: the attacked sample has to match its input exactly.
                assert np.array_equal(adv_x[adv_row], working_sample)
                adv_row += 1
            else:
                # Samples filtered out must be left as they were in the fixture.
                assert np.array_equal(working_sample, reference_x[row])
    except ARTTestException as e:
        art_warning(e)


@pytest.mark.skip_framework("pytorch", "mxnet", "non_dl_frameworks", "tensorflow1", "keras", "kerastf", "tensorflow2v1")
def test_append_attack(art_warning, fix_get_synthetic_data, fix_make_dummy_model):
    """
    Check the append attack with a given L0 budget.
    """
    try:
        classifier, embedding_weights = fix_make_dummy_model
        reference_x, reference_y, reference_sizes = fix_get_synthetic_data

        l0_budget = 1250
        attack = MalwareGDTensorFlow(
            classifier=classifier,
            embedding_weights=embedding_weights,
            l_0=l0_budget,
            param_dic={"maxlen": 2 ** 20, "input_dim": 257, "embedding_size": 8},
        )

        x = np.copy(reference_x)
        y = np.copy(reference_y)
        size_of_files = np.copy(reference_sizes)

        adv_x, adv_y, adv_sizes = attack.pull_out_valid_samples(x, y, size_of_files)

        # Only 2 files remain, as the following cannot be converted to valid adversarial samples:
        #   2nd datapoint (file too large to support any modifications)
        #   4th datapoint (file too large to support append attacks)
        #   5th datapoint (benign file)
        assert len(adv_x) == 2

        adv_x = attack.generate(adv_x, adv_y, adv_sizes)

        attacked_rows = (0, 2)
        adv_row = 0
        for row, file_size in enumerate(reference_sizes):
            if row not in attacked_rows:
                assert np.array_equal(x[row], reference_x[row])
                continue

            append_end = file_size + l0_budget
            # The original file content is untouched ...
            assert np.array_equal(adv_x[adv_row, :file_size], reference_x[row, :file_size])
            # ... the budget is spent directly after the end of the file ...
            assert not np.array_equal(adv_x[adv_row, file_size:append_end], reference_x[row, file_size:append_end])
            # ... and everything after the appended bytes is unchanged.
            assert np.array_equal(adv_x[adv_row, append_end:], reference_x[row, append_end:])
            adv_row += 1
    except ARTTestException as e:
        art_warning(e)


@pytest.mark.skip_framework("pytorch", "mxnet", "non_dl_frameworks", "tensorflow1", "keras", "kerastf", "tensorflow2v1")
def test_slack_attacks(art_warning, fix_get_synthetic_data, fix_make_dummy_model):
    """
    Testing modification of certain regions in the PE file (slack insertion attacks).
    """
    try:

        def generate_synthetic_slack_regions(size):
            """
            Generate 4 slack regions for each of the 5 samples, starting at 1000, 2000, 3000 and 4000.

            :param size: Size given to every slack region.
            :return: A tuple `(starts, sizes)` of per-sample lists.
            """
            samples_in_batch = 5
            regions_per_sample = 4
            starts = [[1000 * (k + 1) for k in range(regions_per_sample)] for _ in range(samples_in_batch)]
            sizes = [[size] * regions_per_sample for _ in range(samples_in_batch)]
            return starts, sizes

        classifier, embedding_weights = fix_make_dummy_model
        reference_x, reference_y, reference_sizes = fix_get_synthetic_data

        # The attack is created without a budget; the budget is assigned on the instance afterwards.
        attack = MalwareGDTensorFlow(
            classifier=classifier,
            embedding_weights=embedding_weights,
            l_0=0,
            param_dic={"maxlen": 2 ** 20, "input_dim": 257, "embedding_size": 8},
        )
        l0_budget = 1250
        attack.l_0 = l0_budget

        x = np.copy(reference_x)
        y = np.copy(reference_y)
        size_of_files = np.copy(reference_sizes)

        region_starts, region_sizes = generate_synthetic_slack_regions(size=250)
        adv_x, adv_y, adv_sizes, region_starts, region_sizes = attack.pull_out_valid_samples(
            x,
            y,
            sample_sizes=size_of_files,
            perturb_starts=region_starts,
            perturb_sizes=region_sizes,
        )

        # Only 2 files remain, as the following cannot be converted to valid adversarial samples:
        #   2nd datapoint (file too large to support any modifications)
        #   4th datapoint (attack requires appending 250 bytes to end of file which this datapoint cannot support)
        #   5th datapoint (benign file)
        assert len(adv_x) == 2

        adv_x = attack.generate(adv_x, adv_y, adv_sizes, perturb_sizes=region_sizes, perturb_starts=region_starts)

        attacked_rows = (0, 2)
        adv_row = 0
        for row, file_size in enumerate(reference_sizes):
            if row not in attacked_rows:
                assert np.array_equal(x[row], reference_x[row])
                continue

            cursor = 0
            budget_used = 0
            for region_start, region_size in zip(region_starts[adv_row], region_sizes[adv_row]):
                region_end = region_start + region_size

                # Bytes between the previous region and this one are untouched.
                assert np.array_equal(adv_x[adv_row, cursor:region_start], reference_x[row, cursor:region_start])
                # The slack region itself has been perturbed.
                assert not np.array_equal(
                    adv_x[adv_row, region_start:region_end], reference_x[row, region_start:region_end]
                )

                cursor = region_end
                budget_used += region_size

            # From the last slack region to the end of the file.
            assert np.array_equal(adv_x[adv_row, cursor:file_size], reference_x[row, cursor:file_size])

            leftover_budget = l0_budget - budget_used
            assert leftover_budget == 250
            append_end = file_size + leftover_budget

            # The append portion of the attack has been conducted.
            assert not np.array_equal(adv_x[adv_row, file_size:append_end], reference_x[row, file_size:append_end])

            # From the end of the append to the end of the datapoint.
            assert np.array_equal(adv_x[adv_row, append_end:], reference_x[row, append_end:])

            adv_row += 1
    except ARTTestException as e:
        art_warning(e)


@pytest.mark.skip_framework("pytorch", "mxnet", "non_dl_frameworks", "tensorflow1", "keras", "kerastf", "tensorflow2v1")
def test_large_append(art_warning, fix_get_synthetic_data, fix_make_dummy_model):
    """
    Testing the append attack with a very large perturbation budget (20% of the maximum length).
    """
    try:
        classifier, embedding_weights = fix_make_dummy_model
        reference_x, reference_y, reference_sizes = fix_get_synthetic_data

        l0_budget = int(((2 ** 20) * 0.2))
        attack = MalwareGDTensorFlow(
            classifier=classifier,
            embedding_weights=embedding_weights,
            l_0=l0_budget,
            param_dic={"maxlen": 2 ** 20, "input_dim": 257, "embedding_size": 8},
        )

        x = np.copy(reference_x)
        y = np.copy(reference_y)
        size_of_files = np.copy(reference_sizes)

        adv_x, adv_y, adv_sizes = attack.pull_out_valid_samples(x, y, sample_sizes=size_of_files)
        # Only one datapoint can support an append perturbation of this size.
        assert len(adv_x) == 1

        adv_x = attack.generate(adv_x, adv_y, adv_sizes)

        adv_row = 0  # the single surviving sample
        for row, file_size in enumerate(reference_sizes):
            if row != 0:
                assert np.array_equal(x[row], reference_x[row])
                continue

            append_end = file_size + l0_budget
            # Original content kept, appended region changed, remainder kept.
            assert np.array_equal(adv_x[adv_row, :file_size], reference_x[row, :file_size])
            assert not np.array_equal(adv_x[adv_row, file_size:append_end], reference_x[row, file_size:append_end])
            assert np.array_equal(adv_x[adv_row, append_end:], reference_x[row, append_end:])
    except ARTTestException as e:
        art_warning(e)


@pytest.mark.skip_framework("pytorch", "mxnet", "non_dl_frameworks", "tensorflow1", "keras", "kerastf", "tensorflow2v1")
def test_dos_header_attack(art_warning, fix_get_synthetic_data, fix_make_dummy_model):
    """
    Test that the DOS header attack modifies the correct regions.
    """
    try:
        classifier, embedding_weights = fix_make_dummy_model
        reference_x, reference_y, reference_sizes = fix_get_synthetic_data

        l0_budget = 290
        attack = MalwareGDTensorFlow(
            classifier=classifier,
            embedding_weights=embedding_weights,
            l_0=l0_budget,
            param_dic={"maxlen": 2 ** 20, "input_dim": 257, "embedding_size": 8},
        )

        x = np.copy(reference_x)
        y = np.copy(reference_y)
        size_of_files = np.copy(reference_sizes)

        dos_starts, dos_sizes = attack.get_dos_locations(x)

        adv_x, adv_y, adv_sizes, region_starts, region_sizes = attack.pull_out_valid_samples(
            x, y, sample_sizes=size_of_files, perturb_starts=dos_starts, perturb_sizes=dos_sizes
        )

        # Should have 3 files. Samples which are excluded are:
        #   2nd datapoint (file too large to support any modifications)
        #   5th datapoint (benign file)
        assert len(adv_x) == 3

        adv_x = attack.generate(adv_x, adv_y, adv_sizes, perturb_sizes=region_sizes, perturb_starts=region_starts)

        # Offsets of the regions checked below.
        pointer_start = int(0x3C)
        pointer_end = pointer_start + 4
        # 58 bytes lie between the first two bytes and the pointer, so 290 - 58 = 232 bytes of budget remain.
        second_region_end = pointer_end + 232

        attacked_rows = (0, 2, 3)
        adv_row = 0
        for row, _ in enumerate(reference_sizes):
            if row not in attacked_rows:
                continue

            # The first two bytes keep the values written by the data fixture.
            assert np.array_equal(adv_x[adv_row, 0:2], [77, 90])

            # The 58 bytes between the first two bytes and the pointer were perturbed.
            assert not np.array_equal(adv_x[adv_row, 2:pointer_start], reference_x[row, 2:pointer_start])

            # The dummy pointer should be unchanged.
            assert np.array_equal(adv_x[adv_row, pointer_start:pointer_end], [44, 1, 0, 0])

            # The remaining perturbation is in the rest of the DOS header.
            assert not np.array_equal(
                adv_x[adv_row, pointer_end:second_region_end], reference_x[row, pointer_end:second_region_end]
            )

            # The rest of the file is unchanged.
            assert np.array_equal(adv_x[adv_row, second_region_end:], reference_x[row, second_region_end:])
            adv_row += 1
    except ARTTestException as e:
        art_warning(e)


@pytest.mark.skip_framework("pytorch", "mxnet", "non_dl_frameworks", "tensorflow1", "keras", "kerastf", "tensorflow2v1")
def test_no_auto_append(art_warning, fix_get_synthetic_data, fix_make_dummy_model):
    """
    Verify behaviour when not spilling extra perturbation into an append attack.

    With `automatically_append=False` only the supplied slack regions may change; everything else, including the
    padding after the end of the file, has to match the original data.
    """
    try:

        def generate_synthetic_slack_regions(size):
            """
            Generate 4 slack regions for each of the 5 samples, starting at 1000, 2000, 3000 and 4000.

            :param size: Size given to every slack region.
            :return: A tuple `(starts, sizes)` of per-sample lists.
            """

            batch_of_slack_starts = []
            batch_of_slack_sizes = []

            for _ in range(5):
                size_of_slack = []
                start_of_slack = []
                start = 0
                for _ in range(4):
                    start += 1000
                    start_of_slack.append(start)
                    size_of_slack.append(size)
                batch_of_slack_starts.append(start_of_slack)
                batch_of_slack_sizes.append(size_of_slack)
            return batch_of_slack_starts, batch_of_slack_sizes

        # The budget (1250) exceeds the total slack space (4 * 250), so 250 would spill over if appending were on.
        l0_budget = 1250
        param_dic = {"maxlen": 2 ** 20, "input_dim": 257, "embedding_size": 8}
        attack = MalwareGDTensorFlow(
            classifier=fix_make_dummy_model[0],
            embedding_weights=fix_make_dummy_model[1],
            l_0=l0_budget,
            param_dic=param_dic,
        )
        x = np.copy(fix_get_synthetic_data[0])
        y = np.copy(fix_get_synthetic_data[1])
        size_of_files = np.copy(fix_get_synthetic_data[2])

        attack.l_0 = l0_budget

        batch_of_section_starts, batch_of_section_sizes = generate_synthetic_slack_regions(size=250)
        adv_x, adv_y, adv_sizes, batch_of_section_starts, batch_of_section_sizes = attack.pull_out_valid_samples(
            x,
            y,
            sample_sizes=size_of_files,
            perturb_starts=batch_of_section_starts,
            perturb_sizes=batch_of_section_sizes,
            automatically_append=False,
        )
        # Should have 3 samples:
        #   2nd datapoint cannot support any modification
        #   5th datapoint is a benign sample
        assert len(adv_x) == 3

        adv_x = attack.generate(
            adv_x,
            adv_y,
            adv_sizes,
            automatically_append=False,
            perturb_sizes=batch_of_section_sizes,
            perturb_starts=batch_of_section_starts,
        )

        j = 0  # row index into the filtered adversarial batch
        for i, original_sample in enumerate(fix_get_synthetic_data[0]):
            if i in (0, 2, 3):
                beginning_pos = 0
                for slack_start, slack_size in zip(batch_of_section_starts[j], batch_of_section_sizes[j]):
                    slack_end = slack_start + slack_size

                    # Bytes between the previous slack region and this one are untouched.
                    assert np.array_equal(
                        adv_x[j, beginning_pos:slack_start], original_sample[beginning_pos:slack_start]
                    )

                    # The slack region itself has been perturbed.
                    assert not np.array_equal(adv_x[j, slack_start:slack_end], original_sample[slack_start:slack_end])

                    # Move the position to the end of the slack region.
                    beginning_pos = slack_end

                # From the end of the final inserted perturbation to the end of the row: nothing was appended.
                assert np.array_equal(adv_x[j, beginning_pos:], original_sample[beginning_pos:])
                j += 1
    except ARTTestException as e:
        art_warning(e)


@pytest.mark.skip_framework("pytorch", "mxnet", "non_dl_frameworks", "tensorflow1", "keras", "kerastf", "tensorflow2v1")
def test_do_not_check_for_valid(art_warning, fix_get_synthetic_data, fix_make_dummy_model):
    """
    No checking for valid data. Expect a mix of adversarial and unmodified data to be returned.
    """
    try:
        classifier, embedding_weights = fix_make_dummy_model
        reference_x, reference_y, reference_sizes = fix_get_synthetic_data

        l0_budget = 1250
        attack = MalwareGDTensorFlow(
            classifier=classifier,
            embedding_weights=embedding_weights,
            l_0=l0_budget,
            param_dic={"maxlen": 2 ** 20, "input_dim": 257, "embedding_size": 8},
        )

        x = np.copy(reference_x)
        y = np.copy(reference_y)
        size_of_files = np.copy(reference_sizes)

        attack.l_0 = l0_budget

        # The whole batch is passed straight to the attack, so all 5 rows come back.
        adv_x = attack.generate(x, y, size_of_files, verify_input_data=False)
        assert len(adv_x) == 5

        # We expect 2 files to have been made adversarial; the following cannot be converted to valid adv samples:
        #   2nd datapoint (file too large to support any modifications)
        #   4th datapoint (file too large to support append attacks)
        #   5th datapoint (benign file)
        attacked_rows = (0, 2)
        for row, file_size in enumerate(size_of_files):
            if row not in attacked_rows:
                assert np.array_equal(adv_x[row], reference_x[row])
                continue

            append_end = file_size + l0_budget
            assert np.array_equal(adv_x[row, :file_size], reference_x[row, :file_size])
            assert not np.array_equal(adv_x[row, file_size:append_end], reference_x[row, file_size:append_end])
            assert np.array_equal(adv_x[row, append_end:], reference_x[row, append_end:])
    except ARTTestException as e:
        art_warning(e)


@pytest.mark.skip_framework("pytorch", "mxnet", "non_dl_frameworks", "tensorflow1", "keras", "kerastf", "tensorflow2v1")
def test_check_params(art_warning, image_dl_estimator_for_attack):
    """
    Every invalid constructor argument listed below is expected to raise a `ValueError`.
    """
    try:
        classifier = image_dl_estimator_for_attack(MalwareGDTensorFlow)

        # Wrong container types for the two required arguments.
        invalid_configurations = [
            {"param_dic": [1, 2, 3], "embedding_weights": np.array([1, 2, 3])},
            {"param_dic": {"test": 1}, "embedding_weights": [1, 2, 3]},
        ]

        # One bad optional argument at a time, combined with otherwise acceptable required arguments.
        invalid_options = [
            {"l_0": "1"},
            {"l_0": -1},
            {"l_r": "1"},
            {"l_r": -1},
            {"use_sign": "true"},
            {"num_of_iterations": 1.0},
            {"num_of_iterations": -1},
            {"verbose": "true"},
        ]
        invalid_configurations.extend(
            {"param_dic": {"test": 1}, "embedding_weights": np.array([1, 2, 3]), **option}
            for option in invalid_options
        )

        for attack_kwargs in invalid_configurations:
            with pytest.raises(ValueError):
                _ = MalwareGDTensorFlow(classifier, **attack_kwargs)

    except ARTTestException as e:
        art_warning(e)


@pytest.mark.framework_agnostic
def test_classifier_type_check_fail(art_warning, fix_make_dummy_model):
    """
    Run the shared classifier type check helper against `MalwareGDTensorFlow`.
    """
    try:
        _, embedding_weights = fix_make_dummy_model
        required_estimator_types = [BaseEstimator, NeuralNetworkMixin, ClassifierMixin]
        attack_kwargs = {
            "classifier": None,
            "param_dic": {"maxlen": 2 ** 20, "input_dim": 257, "embedding_size": 8},
            "embedding_weights": embedding_weights,
        }
        backend_test_classifier_type_check_fail(MalwareGDTensorFlow, required_estimator_types, **attack_kwargs)

    except ARTTestException as e:
        art_warning(e)
